import sys, os
sys.path.append(os.pardir)  #为了导入父目录中的文件而进行的设定
from dataset.mnist import load_mnist
import numpy as np
from PIL import Image

#下载和读取MNIST数据集
#第一次调用会花费几分钟
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
#输出各个数据的形状
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)

#显示训练图像的第一张
def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
img = x_train[0]
label = t_train[0]
print(label)
print(img.shape)
img = img.reshape(28, 28)
print(img.shape)
img_show(img)
